class Solution(object):
    def updateMatrix(self, mat):
        """Replace every cell of a 0/1 matrix with its distance to the nearest 0.

        Distance is counted in 4-directional (up/down/left/right) steps.
        A breadth-first search is seeded from every zero cell at once, so
        each cell is assigned exactly once, in the BFS layer equal to its
        distance.

        The matrix is modified in place and the same object is returned.
        Any non-zero cell is treated as "1". Cells that no zero can reach
        are left as -1; in a grid this only happens when the matrix
        contains no zeros at all.

        :type mat: List[List[int]]
        :rtype: List[List[int]]
        """
        # An empty matrix has no first row to measure; nothing to do.
        if not mat:
            return mat
        m = len(mat)
        n = len(mat[0])
        frontier = []
        for i in range(m):
            row = mat[i]
            for j in range(n):
                if row[j]:
                    # -1 marks "distance not yet known".
                    row[j] = -1
                else:
                    frontier.append((i, j))
        dist = 0
        while frontier:
            dist += 1
            next_frontier = []
            for x, y in frontier:
                for ux, uy in ((x - 1, y), (x + 1, y), (x, y - 1), (x, y + 1)):
                    if 0 <= ux < m and 0 <= uy < n and mat[ux][uy] == -1:
                        # First time a cell is reached is its shortest
                        # distance; marking it now also prevents it from
                        # being enqueued twice.
                        mat[ux][uy] = dist
                        next_frontier.append((ux, uy))
            frontier = next_frontier
        return mat
